from relatio.datasets import load_trump_data

# Load the two Trump datasets by name. `split_sentences` is indexed further
# down as [0] (passed as `doc_index`) and [1] (passed as `sentences`);
# `srl_res` is handed to both relatio wrappers.
split_sentences, srl_res = (
    load_trump_data(dataset_name)
    for dataset_name in ("split_sentences", "srl_res")
)

from relatio.wrappers import build_narrative_model

# Keyword options forwarded unchanged to build_narrative_model.
# NOTE(review): judging by their names these are text-cleaning switches;
# confirm their exact semantics against the relatio documentation.
text_cleaning_options = {
    "stop_words": None,
    "max_length": 50,
    "remove_punctuation": True,
    "remove_digits": True,
    "remove_chars": "",
    "lowercase": True,
    "strip": True,
    "remove_whitespaces": True,
    "lemmatize": True,
    "stem": False,
    "tags_to_keep": None,
    "remove_n_letter_words": 1,
}

# Build the narrative model from the SRL output and the sentences stored at
# split_sentences[1], using the "glove-wiki-gigaword-300" keyed vectors.
narrative_model = build_narrative_model(
    srl_res=srl_res,
    sentences=split_sentences[1],
    embeddings_type="gensim_keyed_vectors",
    embeddings_path="glove-wiki-gigaword-300",
    n_clusters=[[100]],
    top_n_entities=100,
    dimension_reduce_verbs=False,
    progress_bar=True,
    **text_cleaning_options,
)

# Show the first 20 items returned by the model's `entities.most_common()`,
# followed by its 'cluster_labels_most_freq' entry.
leading_entities = narrative_model["entities"].most_common()[:20]
print(leading_entities)

print(narrative_model["cluster_labels_most_freq"])

from relatio.wrappers import get_narratives

# Extract the narrative statements with the model built above.
# split_sentences[0] is supplied as the document index.
document_index = split_sentences[0]

final_statements = get_narratives(
    srl_res=srl_res,
    doc_index=document_index,
    narrative_model=narrative_model,
    n_clusters=[0],
    progress_bar=True,
)

import numpy as np

# Build a negation-aware verb column for each representation: when the
# matching B-ARGM-NEG flag is True the verb is prefixed with 'not-', otherwise
# it is copied through unchanged.
#
# Both columns are produced by the same loop so that each "<dim>_with_neg"
# column is always derived from the verb column of the SAME representation.
# (Previously the high-dimensional column prefixed the low-dimensional verb
# on negated rows.)
for dim in ("lowdim", "highdim"):
    verb = final_statements["B-V_" + dim]
    # Explicit comparison kept on purpose: the dtype of the flag column is not
    # visible here, and `== True` treats anything that is not True as
    # "not negated".
    is_negated = final_statements["B-ARGM-NEG_" + dim] == True  # noqa: E712
    final_statements["B-V_" + dim + "_with_neg"] = np.where(
        is_negated, "not-" + verb, verb
    )

# Concatenate high-dimensional narratives (with text preprocessing but no clustering):
# "<ARG0> <verb, possibly negated> <ARG1>", all taken from the *_highdim columns.

final_statements['narrative_highdim'] = (final_statements['ARG0_highdim'] + ' ' +
                                         final_statements['B-V_highdim_with_neg'] + ' ' +
                                         final_statements['ARG1_highdim'])

# Concatenate low-dimensional narratives (with clustering).
# The verb must come from the low-dimensional column as well, so that the
# whole narrative uses one representation (the previous code mixed in the
# high-dimensional verb).

final_statements['narrative_lowdim'] = (final_statements['ARG0_lowdim'] + ' ' +
                                        final_statements['B-V_lowdim_with_neg'] + ' ' +
                                        final_statements['ARG1_lowdim'])

# Focus on narratives with an ARG0-VERB-ARG1 structure (i.e. "complete
# narratives"): a row is incomplete when any of the three low-dimensional
# components is the empty string.

required_columns = ['ARG0_lowdim', 'ARG1_lowdim', 'B-V_lowdim_with_neg']
is_incomplete = (final_statements[required_columns] == '').any(axis=1)
indexNames = final_statements[is_incomplete].index

# Drop the incomplete rows by index label.
complete_narratives = final_statements.drop(indexNames)

from relatio.graphs import build_graph, draw_graph

# Take the low-dimensional (ARG0, ARG1, B-V) triples of the complete
# narratives and give the columns short names.
triples = complete_narratives[["ARG0_lowdim", "ARG1_lowdim", "B-V_lowdim"]].rename(
    columns={"ARG0_lowdim": "ARG0", "ARG1_lowdim": "ARG1", "B-V_lowdim": "B-V"}
)
has_all_parts = (
    (triples["ARG0"] != "") & (triples["ARG1"] != "") & (triples["B-V"] != "")
)

# Count how often each triple occurs and keep the first 50 rows after sorting
# by that count in descending order, as a list of record dicts.
temp = (
    triples[has_all_parts]
    .groupby(["ARG0", "ARG1", "B-V"])
    .size()
    .reset_index(name="weight")
    .sort_values(by="weight", ascending=False)
    .head(50)
    .to_dict(orient="records")
)

# Every edge record gets a "color" key set to None.
temp = [{**edge, "color": None} for edge in temp]

G = build_graph(
    dict_edges=temp,
    dict_args={},
    edge_size=None,
    node_size=10,
    prune_network=True,
)

draw_graph(G, notebook=True, output_filename="../graphs/Figure_J_1.html")
